1#![allow(dead_code)]
12
13use std::io::Read;
14use std::path::{Component, Path, PathBuf};
15
16use path_clean::PathClean;
17use thiserror::Error;
18
19#[derive(Debug, Error)]
21pub enum Error {
22 #[error("tarball path traversal attempt: {entry_path}")]
25 Escape { entry_path: PathBuf },
26
27 #[error("unsupported tarball entry type '{kind}': {entry_path}")]
30 UnsupportedEntry {
31 entry_path: PathBuf,
32 kind: &'static str,
33 },
34
35 #[error("pre-existing symlink in target tree: {path}")]
38 PrePlantedSymlink { path: PathBuf },
39
40 #[error(transparent)]
42 Io(#[from] std::io::Error),
43}
44
45pub fn extract_safe(source: impl Read, target: &Path) -> Result<(), Error> {
66 let canon_target = target.canonicalize()?;
68
69 check_ancestors_not_symlinks(&canon_target)?;
71
72 check_children_not_symlinks(&canon_target)?;
74
75 let gz = flate2::read::GzDecoder::new(source);
77 let mut archive = tar::Archive::new(gz);
78
79 for entry_result in archive.entries()? {
80 let mut entry = entry_result?;
81 let header = entry.header();
82
83 let entry_type = header.entry_type();
85 if !entry_type.is_file() && !entry_type.is_dir() {
86 let kind = classify_entry_type(entry_type);
87 let entry_path = entry.path()?.into_owned();
88 return Err(Error::UnsupportedEntry { entry_path, kind });
89 }
90
91 let raw_path = entry.path()?.into_owned();
93
94 if raw_path.is_absolute() {
96 return Err(Error::Escape {
97 entry_path: raw_path,
98 });
99 }
100
101 for component in raw_path.components() {
103 match component {
104 Component::ParentDir | Component::RootDir | Component::Prefix(_) => {
105 return Err(Error::Escape {
106 entry_path: raw_path,
107 });
108 }
109 _ => {}
110 }
111 }
112
113 let resolved = canon_target.join(&raw_path).clean();
115 if !starts_with_normalized(&resolved, &canon_target) {
116 return Err(Error::Escape {
117 entry_path: raw_path,
118 });
119 }
120
121 entry.unpack_in(&canon_target)?;
123 }
124
125 Ok(())
126}
127
128fn check_ancestors_not_symlinks(path: &Path) -> Result<(), Error> {
130 let mut current = path.to_path_buf();
131 while let Some(parent) = current.parent() {
132 if parent.as_os_str().is_empty() {
133 break;
134 }
135 if parent.symlink_metadata()?.file_type().is_symlink() {
136 return Err(Error::PrePlantedSymlink {
137 path: parent.to_path_buf(),
138 });
139 }
140 current = parent.to_path_buf();
141 }
142 Ok(())
143}
144
145fn check_children_not_symlinks(path: &Path) -> Result<(), Error> {
147 if !path.is_dir() {
148 return Ok(());
149 }
150 for entry in std::fs::read_dir(path)? {
151 let entry = entry?;
152 if entry.metadata()?.file_type().is_symlink() {
153 return Err(Error::PrePlantedSymlink { path: entry.path() });
154 }
155 }
156 Ok(())
157}
158
/// Reports whether `path` lies at or below `prefix`, comparing whole path components.
///
/// On Windows the comparison ignores case. It is done component by component rather
/// than as a raw string prefix, so a sibling such as `C:\dest-evil` is not mistaken
/// for being inside `C:\dest`. On other platforms this is exactly
/// [`Path::starts_with`].
fn starts_with_normalized(path: &Path, prefix: &Path) -> bool {
    #[cfg(windows)]
    {
        let fold = |part: &std::ffi::OsStr| part.to_string_lossy().to_lowercase();
        let mut have = path.components();
        // Every component of `prefix` must match the component at the same position
        // in `path`; running out of `path` components first means no match.
        prefix.components().all(|want| {
            have.next()
                .map_or(false, |got| fold(got.as_os_str()) == fold(want.as_os_str()))
        })
    }
    #[cfg(not(windows))]
    {
        path.starts_with(prefix)
    }
}
173
174fn classify_entry_type(t: tar::EntryType) -> &'static str {
176 match t {
177 tar::EntryType::Symlink => "symlink",
178 tar::EntryType::Link => "hardlink",
179 tar::EntryType::Char => "char device",
180 tar::EntryType::Block => "block device",
181 tar::EntryType::Fifo => "fifo",
182 tar::EntryType::GNUSparse => "sparse",
183 _ => "unknown",
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use flate2::write::GzEncoder;
    use flate2::Compression;
    use std::io::{Cursor, Write};
    use tempfile::TempDir;

    /// Gzip-compresses an already-built tar byte stream.
    fn gzip(tar_bytes: &[u8]) -> Vec<u8> {
        let mut encoder = GzEncoder::new(Vec::new(), Compression::fast());
        encoder.write_all(tar_bytes).unwrap();
        encoder.finish().unwrap()
    }

    /// Creates a temp dir containing an empty `dest` directory to extract into.
    ///
    /// The returned `TempDir` must be kept alive for as long as the path is used.
    fn fresh_target() -> (TempDir, PathBuf) {
        let tmp = TempDir::new().unwrap();
        let target = tmp.path().join("dest");
        std::fs::create_dir_all(&target).unwrap();
        (tmp, target)
    }

    /// Builds a `.tar.gz` from `(path, content, is_dir)` triples.
    ///
    /// Directory entries are written with no data; their `content` is ignored.
    fn build_tar_gz(entries: &[(&str, &[u8], bool)]) -> Vec<u8> {
        let mut builder = tar::Builder::new(Vec::new());
        for &(path, content, is_dir) in entries {
            let body: &[u8] = if is_dir { &[] } else { content };
            let (kind, mode) = if is_dir {
                (tar::EntryType::Directory, 0o755)
            } else {
                (tar::EntryType::Regular, 0o644)
            };

            let mut header = tar::Header::new_gnu();
            header.set_entry_type(kind);
            header.set_size(body.len() as u64);
            header.set_mode(mode);
            header.set_cksum();
            builder.append_data(&mut header, path, body).unwrap();
        }
        gzip(&builder.into_inner().unwrap())
    }

    /// Builds a `.tar.gz` holding a single link entry of the given `kind`.
    fn build_tar_gz_with_link(
        kind: tar::EntryType,
        mode: u32,
        link_name: &str,
        target: &str,
    ) -> Vec<u8> {
        let mut builder = tar::Builder::new(Vec::new());
        let mut header = tar::Header::new_gnu();
        header.set_entry_type(kind);
        header.set_size(0);
        header.set_mode(mode);
        header.set_cksum();
        builder.append_link(&mut header, link_name, target).unwrap();
        gzip(&builder.into_inner().unwrap())
    }

    /// Builds a `.tar.gz` holding a single symlink `link_name -> target`.
    fn build_tar_gz_with_symlink(link_name: &str, target: &str) -> Vec<u8> {
        build_tar_gz_with_link(tar::EntryType::Symlink, 0o777, link_name, target)
    }

    /// Builds a `.tar.gz` holding a single hard link `link_name -> target`.
    fn build_tar_gz_with_hardlink(link_name: &str, target: &str) -> Vec<u8> {
        build_tar_gz_with_link(tar::EntryType::Link, 0o644, link_name, target)
    }

    /// Builds a `.tar.gz` whose single file entry carries `malicious_path` as its name.
    ///
    /// The archive is first written with a harmless placeholder name, then the raw
    /// header bytes are patched and the header checksum recomputed.
    fn build_tar_gz_with_forged_path(malicious_path: &str) -> Vec<u8> {
        let placeholder = "PLACEHOLDER_PATH_FOR_FORGING";
        assert!(
            malicious_path.len() <= placeholder.len(),
            "malicious path too long for placeholder"
        );

        let content = b"evil";
        let mut builder = tar::Builder::new(Vec::new());
        let mut header = tar::Header::new_gnu();
        header.set_entry_type(tar::EntryType::Regular);
        header.set_size(content.len() as u64);
        header.set_mode(0o644);
        header.set_cksum();
        builder
            .append_data(&mut header, placeholder, &content[..])
            .unwrap();
        let mut tar_bytes = builder.into_inner().unwrap();

        // Fail loudly if the placeholder is missing: otherwise the caller would be
        // handed an archive that was never forged.
        let pos = tar_bytes
            .windows(placeholder.len())
            .position(|w| w == placeholder.as_bytes())
            .expect("placeholder path not found in generated tar header");

        // Overwrite the name: zero the old bytes, then write the malicious path.
        let name_field = &mut tar_bytes[pos..pos + placeholder.len()];
        name_field.fill(0);
        name_field[..malicious_path.len()].copy_from_slice(malicious_path.as_bytes());

        // Recompute the checksum over the 512-byte header, with the checksum field
        // (bytes 148..156) counted as ASCII spaces.
        let header_start = pos - (pos % 512);
        let raw_header = &mut tar_bytes[header_start..header_start + 512];
        raw_header[148..156].fill(b' ');
        let sum: u64 = raw_header.iter().map(|&b| u64::from(b)).sum();
        let cksum = format!("{sum:06o}\0 ");
        raw_header[148..156].copy_from_slice(cksum.as_bytes());

        gzip(&tar_bytes)
    }

    #[test]
    fn extracts_regular_files_to_target() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz(&[
            ("core/", b"", true),
            ("core/manifest.json", b"{\"version\":\"0.1\"}", false),
        ]);

        extract_safe(Cursor::new(gz), &target).unwrap();

        let manifest = target.join("core/manifest.json");
        assert!(manifest.exists(), "manifest.json should exist");
        assert_eq!(
            std::fs::read_to_string(manifest).unwrap(),
            "{\"version\":\"0.1\"}"
        );
    }

    #[test]
    fn extracts_nested_directories() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz(&[
            ("core/", b"", true),
            ("core/skills/", b"", true),
            ("core/skills/query-installation/", b"", true),
            ("core/skills/query-installation/SKILL.md", b"# Query", false),
        ]);

        extract_safe(Cursor::new(gz), &target).unwrap();

        let file = target.join("core/skills/query-installation/SKILL.md");
        assert!(file.exists());
        assert_eq!(std::fs::read_to_string(file).unwrap(), "# Query");
    }

    #[test]
    fn empty_tarball_succeeds() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz(&[]);
        extract_safe(Cursor::new(gz), &target).unwrap();
    }

    #[test]
    fn rejects_dotdot_path_traversal() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz_with_forged_path("../../etc/passwd");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        assert!(
            matches!(err, Error::Escape { .. }),
            "Expected Escape, got: {err:?}"
        );
    }

    #[test]
    fn rejects_absolute_path() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz_with_forged_path("/tmp/evil");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        assert!(
            matches!(err, Error::Escape { .. }),
            "Expected Escape, got: {err:?}"
        );
    }

    #[test]
    fn rejects_symlink_entry() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz_with_symlink("evil-link", "/etc/passwd");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        match err {
            Error::UnsupportedEntry { kind, .. } => assert_eq!(kind, "symlink"),
            other => panic!("Expected UnsupportedEntry(symlink), got: {other:?}"),
        }
    }

    #[test]
    fn rejects_hardlink_entry() {
        let (_tmp, target) = fresh_target();

        let gz = build_tar_gz_with_hardlink("evil-link", "core/manifest.json");
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        match err {
            Error::UnsupportedEntry { kind, .. } => assert_eq!(kind, "hardlink"),
            other => panic!("Expected UnsupportedEntry(hardlink), got: {other:?}"),
        }
    }

    #[test]
    fn rejects_pre_planted_symlink_child() {
        let (tmp, target) = fresh_target();

        let decoy = tmp.path().join("decoy");
        std::fs::create_dir_all(&decoy).unwrap();

        // Plant a symlink inside the target that points at a directory outside it.
        let symlink_path = target.join("evil");
        #[cfg(unix)]
        std::os::unix::fs::symlink(&decoy, &symlink_path).unwrap();
        #[cfg(windows)]
        {
            if std::os::windows::fs::symlink_dir(&decoy, &symlink_path).is_err() {
                eprintln!("Skipping: symlink creation requires Developer Mode");
                return;
            }
        }

        let gz = build_tar_gz(&[("core/manifest.json", b"data", false)]);
        let err = extract_safe(Cursor::new(gz), &target).unwrap_err();
        assert!(
            matches!(err, Error::PrePlantedSymlink { .. }),
            "Expected PrePlantedSymlink, got: {err:?}"
        );
    }

    #[test]
    fn escape_error_names_path() {
        let err = Error::Escape {
            entry_path: PathBuf::from("../../etc/passwd"),
        };
        let display = format!("{err}");
        assert!(display.contains("../../etc/passwd"));
    }

    #[test]
    fn unsupported_entry_names_kind_and_path() {
        let err = Error::UnsupportedEntry {
            entry_path: PathBuf::from("evil-link"),
            kind: "symlink",
        };
        let display = format!("{err}");
        assert!(display.contains("symlink"));
        assert!(display.contains("evil-link"));
    }

    #[test]
    fn pre_planted_symlink_error_names_path() {
        let err = Error::PrePlantedSymlink {
            path: PathBuf::from("/some/path"),
        };
        let display = format!("{err}");
        assert!(display.contains("/some/path"));
    }
}